'''
Created on May 17, 2018

@author: helrewaidy
'''
# models

import argparse

import torch
import numpy as np
import os

########################## Initializations ########################################
# Architecture names accepted by --arch. This must be a list: a bare string
# makes argparse treat it as a sequence of characters (any substring such as
# 'reco' was accepted by ``choices``) and made the help text read 'r | e | c ...'.
model_names = ['recoNet_Model1']
parser = argparse.ArgumentParser(description='PyTorch MD-CNN Training')
# parser.add_argument('data', metavar='DIR',
#                     help='path to dataset')
# NOTE(review): the default 'resnet18' is not in ``model_names``; argparse does
# not validate defaults against ``choices``, so it is kept as-is for
# compatibility with any code that reads ``args.arch``.
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')
parser.add_argument('--cpu', '-c', action='store_true',
                    help='Do not use the cuda version of the net',
                    default=False)
parser.add_argument('--viz', '-v', action='store_true',
                    help='Visualize the images as they are processed',
                    default=False)
# store_true (not store_false): with default=False, store_false left the value
# False whether or not the flag was given, so -n/--no-save did nothing.
parser.add_argument('--no-save', '-n', action='store_true',
                    help='Do not save the output masks',
                    default=False)
parser.add_argument('--model', '-m', default='MODEL_EPOCH417.pth',
                    metavar='FILE',
                    help='Specify the file in which is stored the model'
                         " (default : 'MODEL_EPOCH417.pth')")
###################################################################

# Restrict CUDA to GPUs 0-3. This runs at import time and overwrites any
# CUDA_VISIBLE_DEVICES already present in the environment; it must happen
# before CUDA is initialised in this process to have any effect.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'

class Parameters():
    """Central configuration object for the MD-CNN cine reconstruction code.

    Instantiating this class prints the GPU count and batch size, parses the
    command line with the module-level ``parser`` (``parse_args()`` reads
    ``sys.argv``), and stores hardware, network, training, and dataset/path
    settings as plain attributes.

    Some values are hard-coded here and override the command line:
    ``self.args.lr`` is forced to 0.001, and ``self.args.model`` is replaced by
    ``model_save_dir + 'MODEL_EPOCH.pth'``.

    Raises:
        ValueError: if ``Op_Node`` is not one of 'myPC', 'alpha_V12',
            'spider', 'O2'.
    """

    def __init__(self):
        super(Parameters, self).__init__()

        ## Hardware/GPU parameters =================================================
        self.Op_Node = 'spider'  # one of 'myPC', 'alpha_V12', 'spider', 'O2'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tbVisualize = False
        self.tbVisualize_kernels = False
        self.tbVisualize_featuremaps = False
        self.multi_GPU = True

        single_gpu_nodes = ('myPC', 'alpha_V12')
        multi_gpu_nodes = ('spider', 'O2')
        if self.Op_Node not in single_gpu_nodes + multi_gpu_nodes:
            # Without this guard device_ids / batch_size are never assigned and
            # the failure only shows up later as an AttributeError.
            raise ValueError('Unknown Op_Node %r; expected one of %s'
                             % (self.Op_Node, ', '.join(single_gpu_nodes + multi_gpu_nodes)))

        if self.Op_Node in single_gpu_nodes:
            self.device_ids = [0]
        else:
            # Every GPU visible to CUDA (empty when no GPU is available).
            self.device_ids = list(range(torch.cuda.device_count()))

        # Attribute name is misspelled ("loders") but kept for existing callers.
        if self.Op_Node in ['spider', 'O2', 'alpha_V12']:
            self.data_loders_num_workers = 1
        else:
            self.data_loders_num_workers = 4

        ## Network/Model parameters =================================================
        self.network_type = '2D'
        self.num_slices_3D = 7
        if self.Op_Node in single_gpu_nodes:
            self.batch_size = 3
        else:
            # Previously considered scaling:
            # 15 * len(device_ids) // 8 // (num_slices_3D if network_type == '3D' else 1)
            self.batch_size = 15

        print('-- # GPUs: ', len(self.device_ids))
        print('-- batch_size: ', self.batch_size)
        self.args = parser.parse_args()

        self.activation_func = 'CReLU'  # alternatives: 'CLeakyReLU', 'modReLU', 'KAF2D', 'ZReLU'
        # Hard-coded override: any --lr given on the command line is ignored.
        self.args.lr = 0.001
        self.dropout_ratio = 0.0
        self.epochs = 550
        self.training_percent = 0.8
        self.nIterations = 1
        self.magnitude_only = False
        self.Validation_Only = False
        self.Evaluation = False

        self.MODEL = 9  # Complex Network takes neighborhood matrix input and image domain output
        self.complex_net = True

        ## Dataset and paths =================================================
        self.g_methods = ['pyNUFFT', 'BART', 'python_interp']
        self.gridding_method = self.g_methods[1]  # 'BART'

        self.gd_methods = ['RING', 'AC-ADDAPTIVE', 'NONE']
        self.gradient_delays_method = ''  # empty string: none of gd_methods selected
        self.rot_angle = True

        self.ds_total_num_slices = 0
        self.patients = []
        self.num_phases = 25
        self.radial_cine = True
        self.n_spokes = 14  # other values tried: 16, 20, 33
        # np.round returns a numpy float, so arch_name below contains e.g. 'R14.0'
        # for radial data; changing this type would change the model directories.
        self.Rate = np.round(198 / self.n_spokes) if self.radial_cine else 3
        self.input_slices = list()
        self.num_slices_per_patient = list()
        self.groundTruth_slices = list()
        self.training_patients_index = list()
        self.us_rates = list()
        self.saveVolumeData = False
        self.multiCoilInput = True
        self.coilCombinedInputTV = True
        self.moving_window_size = 7

        if self.network_type == '2D':
            self.img_size = [208, 208]
        else:
            self.img_size = [208, 208, self.moving_window_size]

        self.n_channels = 8 if self.multiCoilInput else 1

        self.cropped_dataset64 = False

        self.trialNum = 'MultiDomainCNN'
        self.arch_name = 'Model_0' + str(self.MODEL) + '_R' + str(self.Rate) + 'Trial' + self.trialNum

        # NOTE(review): this is a *set* containing one path, not a string;
        # presumably consumers iterate over it — confirm before changing.
        self.dir = {'/data2/helrewaidy/cine_recon/dataset/'}
        self.model_save_dir = '/data2/helrewaidy/cine_recon/models/' + self.arch_name + '/'
        self.net_save_dir = '/data2/helrewaidy/cine_recon/matlab_workspace/'
        self.tensorboard_dir = '/data2/helrewaidy/cine_recon/models/' + self.arch_name + '_tensorboard/'

        # Hard-coded override: any --model given on the command line is ignored.
        self.args.model = self.model_save_dir + 'MODEL_EPOCH.pth'














